from collections import defaultdict
import logging
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.forward_adaptor import ForwardAdaptor
from espnet2.train.abs_espnet_model import AbsESPnetModel


def _collect_stats_for_mode(
    model: AbsESPnetModel,
    itr: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    mode: str,
    output_dir: Path,
    ngpu: int,
    log_interval: int,
    write_collected_feats: bool,
) -> None:
    """Write shape files and feature statistics for a single data split.

    Private helper of collect_stats(); it is expected to be called with
    gradient tracking already disabled by the caller.

    Args:
        model: Model whose ``collect_feats`` is called on every batch.
        itr: Iterable yielding ``(utterance_ids, batch)`` pairs.
        mode: Name of the split; used as the sub-directory of ``output_dir``.
        output_dir: Root directory for all generated files.
        ngpu: Number of GPUs (already normalized to an int).
        log_interval: Log the iteration count every this many batches.
        write_collected_feats: If True, also dump every extracted feature
            sequence through NpyScpWriter.
    """
    mode_dir = output_dir / mode
    device = "cuda" if ngpu > 0 else "cpu"

    sum_dict = defaultdict(lambda: 0)
    sq_dict = defaultdict(lambda: 0)
    count_dict = defaultdict(lambda: 0)

    # Non-"_lengths" names of the most recent batch of *this* split.
    # Tracked explicitly so that an empty iterator neither raises NameError
    # nor silently reuses the batch of a previously processed split.
    batch_keys: List[str] = []
    num_batches = 0

    # One writer per feature name; closed in the "finally" below so that the
    # underlying scp files are flushed and released even if a batch fails.
    npy_scp_writers = {}

    try:
        with DatadirWriter(mode_dir) as datadir_writer:
            for iiter, (keys, batch) in enumerate(itr, 1):
                num_batches = iiter
                batch = to_device(batch, device)
                batch_keys = [n for n in batch if not n.endswith("_lengths")]

                # 1. Write shape file
                for name in batch_keys:
                    for i, (key, data) in enumerate(zip(keys, batch[name])):
                        if f"{name}_lengths" in batch:
                            # Drop the padded tail before taking the shape
                            lg = int(batch[f"{name}_lengths"][i])
                            data = data[:lg]
                        datadir_writer[f"{name}_shape"][key] = ",".join(
                            map(str, data.shape)
                        )

                # 2. Extract feats
                if ngpu <= 1:
                    data = model.collect_feats(**batch)
                else:
                    # Note that data_parallel can parallelize only "forward()"
                    data = data_parallel(
                        ForwardAdaptor(model, "collect_feats"),
                        (),
                        range(ngpu),
                        module_kwargs=batch,
                    )

                # 3. Calculate sum and square sum
                for key, v in data.items():
                    for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                        # Truncate zero-padding region
                        if f"{key}_lengths" in data:
                            # Explicit int, consistent with the shape step
                            length = int(data[f"{key}_lengths"][i])
                            # seq: (Length, Dim, ...)
                            seq = seq[:length]
                        else:
                            # seq: (Dim, ...) -> (1, Dim, ...)
                            seq = seq[None]
                        # Accumulate value, its square, and count
                        sum_dict[key] += seq.sum(0)
                        sq_dict[key] += (seq ** 2).sum(0)
                        count_dict[key] += len(seq)

                        # 4. [Option] Write derived features as npy format file.
                        if write_collected_feats:
                            # Instantiate NpyScpWriter on first use of this key
                            if key not in npy_scp_writers:
                                p = mode_dir / "collect_feats"
                                npy_scp_writers[key] = NpyScpWriter(
                                    p / f"data_{key}", p / f"{key}.scp"
                                )
                            # Save array as npy file
                            npy_scp_writers[key][uttid] = seq

                if iiter % log_interval == 0:
                    logging.info(f"Niter: {iiter}")
    finally:
        for writer in npy_scp_writers.values():
            writer.close()

    if num_batches == 0:
        logging.warning(
            f"No batches were yielded for '{mode}'; "
            "the key files will be written empty."
        )

    # The directory may not exist yet when nothing has been written for this
    # split (e.g. the iterator was empty), so create it before the writes.
    mode_dir.mkdir(parents=True, exist_ok=True)

    for key in sum_dict:
        np.savez(
            mode_dir / f"{key}_stats.npz",
            count=count_dict[key],
            sum=sum_dict[key],
            sum_square=sq_dict[key],
        )

    # batch_keys and stats_keys are used by aggregate_stats_dirs.py
    with (mode_dir / "batch_keys").open("w", encoding="utf-8") as f:
        f.write("\n".join(batch_keys) + "\n")
    with (mode_dir / "stats_keys").open("w", encoding="utf-8") as f:
        f.write("\n".join(sum_dict) + "\n")


@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Perform on collect_stats mode.

    Running for deriving the shape information from data
    and gathering statistics.
    This method is used before executing train().

    For each of the "train" and "valid" splits, the following is produced
    under ``output_dir / <split>``:

    - ``<name>_shape`` entries (via DatadirWriter) holding the comma-joined
      shape of every non-"_lengths" batch item, truncated to its length
      when a matching ``<name>_lengths`` entry exists.
    - ``<key>_stats.npz`` with ``count``, ``sum`` and ``sum_square`` for
      every key returned by ``model.collect_feats``.
    - ``batch_keys`` and ``stats_keys`` text files listing the names above.
    - Optionally, ``collect_feats/`` with every extracted feature sequence.

    Args:
        model: Model providing ``collect_feats(**batch)``.
        train_iter: Iterable of ``(utterance_ids, batch)`` for training data.
        valid_iter: Iterable of ``(utterance_ids, batch)`` for validation data.
        output_dir: Root directory for the generated files.
        ngpu: Number of GPUs. ``None`` is treated as 0 (CPU). Values
            greater than 1 run ``collect_feats`` through ``data_parallel``.
        log_interval: Log progress every this many batches. If ``None``,
            it is derived separately for each split as
            ``max(len(iterator) // 20, 10)``, or 100 when the iterator has
            no length.
        write_collected_feats: If True, dump the extracted features as
            npy files.

    """
    assert check_argument_types()

    # The annotation allows None, but the comparisons below need an int.
    if ngpu is None:
        ngpu = 0

    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        # Resolve the default per split: the two iterators can differ
        # greatly in length, so the value derived for "train" must not
        # leak into "valid".
        interval = log_interval
        if interval is None:
            try:
                interval = max(len(itr) // 20, 10)
            except TypeError:
                interval = 100

        _collect_stats_for_mode(
            model=model,
            itr=itr,
            mode=mode,
            output_dir=output_dir,
            ngpu=ngpu,
            log_interval=interval,
            write_collected_feats=write_collected_feats,
        )
